# Implementation of the PB2-Mix and PB2 baselines (the class also supports plain PBT and random search).
# This uses the original GPy implementation instead of GPyTorch.

import logging
import os
import shutil
import pandas as pd
import GPy
from hpo.hpo_base import *
from collections import OrderedDict
import random
import copy
import numpy as np
from .pb2_utils import *
from copy import deepcopy
from hpo.utils import get_dim_info, get_reward_from_trajectory, get_start_point
import ConfigSpace.hyperparameters as CSH
from .exp3 import exp3_get_cat
from hpo.casmo.bgpbt import BGPBT
from hpo.casmo.casmo import Casmo4RL
from hpo.utils import is_large


class PB2(BGPBT):

    def __init__(self, env, log_dir,
                 max_timesteps: int = None,
                 pop_size: int = 4,
                 n_init: int = None,
                 verbose: bool = False,
                 ard=False,
                 t_ready: int = None,
                 quantile_fraction: float = .25,
                 # joint_nas_hpo: bool = False,
                 seed: int = None,
                 use_reward_fraction: float = 0.,
                 method='random',
                 # cat_exp='cocabo',
                 num_rounds: int = 10,
                 existing_policy: str = 'resume',
                 backtrack: bool = True,
                 # distillation parameters
                 n_distillation_timesteps: int = int(3e7),
                 distillation: bool = False,
                 patience: int = 15,
                 distill_every: int = int(3e6),
                 arch_policy: str = 'search',
                 t_ready_end: int = None,
                 ):
        """Set up a PB2 / PB2-Mix / PBT / random-search population-based tuner.

        Args:
            env: environment wrapper; this class uses its ``config_space``, ``train_batch``,
                ``distill_batch``, ``max_parallel``, ``env_name`` and ``seed``.
            log_dir: directory for logs and checkpoints.
            method: one of 'pbt', 'random', 'pb2', 'pb2mix'. 'pb2mix' falls back to 'pb2' when the
                search space has no categorical hyperparameters.
            num_rounds: number of rounds handed to the EXP3 categorical selector (pb2mix).
            n_distillation_timesteps, distillation, patience, distill_every, arch_policy:
                distillation settings; ``arch_policy`` is one of 'static', 'schedule', 'search'.
            All remaining arguments are forwarded to ``BGPBT.__init__``.
        """
        super(PB2, self).__init__(env, log_dir=log_dir,
                                  max_timesteps=max_timesteps, pop_size=pop_size, n_init=n_init, verbose=verbose,
                                  ard=ard, t_ready=t_ready, quantile_fraction=quantile_fraction, seed=seed,
                                  use_reward_fraction=use_reward_fraction, existing_policy=existing_policy, backtrack=backtrack,
                                  t_ready_end=t_ready_end)
        # assert 0 < quantile_fraction <= .5
        # assert int(quantile_fraction * pop_size) >= 1, 'quantile_fraction * pop size must be >= 1!'
        # assert 0 < t_ready_start < max_timesteps
        # # assert joint_nas_hpo is False, 'joint_nas_hpo is not yet supported in PB2Mix'

        # PBT exploration: continuous values are multiplied by one of these factors; a categorical
        # value is resampled with probability 1 - categorical_prob.
        self.perturb_amount = [0.8, 1.2]
        self.categorical_prob = 0.5
        self.numRounds = num_rounds  # EXP3 rounds for the categorical bandit (pb2mix)
        assert method in ['pbt', 'random', 'pb2', 'pb2mix']

        self.method = method
        # self.cat_exp = cat_exp

        # Internal representation of the config space for the surrogate (self.mutations): one row
        # per (possibly one-hot expanded) dimension with its name, type and range. Continuous dims get
        # the range [0, 1] (presumably matching ConfigSpace's normalised get_array() vector).
        names = []
        types = []
        ranges = []
        n_cat = 0
        self.one_hot_dims = []
        self.one_hot_dim_indices = []
        for param_name in self.env.config_space:
            hp = self.env.config_space[param_name]
            if type(hp) == CSH.CategoricalHyperparameter:
                choices = list(hp.choices)
                if self.method == 'pb2':
                    # for 'pb2', categoricals are one-hot encoded as continuous [0, 1] dimensions
                    one_hot_names = [f'{param_name}_{choice}' for choice in choices]
                    names.extend(one_hot_names)
                    types.extend(['continuous'] * len(choices))
                    ranges.extend([[0., 1.] for _ in choices])
                    self.one_hot_dims.append(one_hot_names)
                else:
                    names.append(param_name)
                    types.append('categorical')
                    ranges.append(choices)
                n_cat += 1
            else:
                names.append(param_name)
                types.append('continuous')
                ranges.append([0., 1.])
        self.mutations = pd.DataFrame({
            'Name': names,
            'Type': types,
            'Range': ranges})
        if n_cat == 0 and self.method == 'pb2mix':
            logging.info('Method=pb2mix but there is no categorical dimensions in the search space. Fall back to PB2')
            self.method = 'pb2'
        # positions (within self.mutations) of each one-hot group, used to invert the encoding
        for group in self.one_hot_dims:
            self.one_hot_dim_indices.append([names.index(c) for c in group])

        # distillation settings -- only used if distillation is enabled
        self.distillation = distillation
        self.n_distillation_timesteps = n_distillation_timesteps
        self.patience = patience
        self.distill_every = distill_every
        assert arch_policy in ['static', 'schedule', 'search']
        self.arch_policy = arch_policy
        if self.resumed:
            import ast

            def _parse_layer_sizes(value):
                # Layer sizes stored in the stats dataframe may come back as their string repr
                # (e.g. "[32, 32]") after a CSV round-trip. Parse them with literal_eval rather than
                # eval() so the file contents can never execute code.
                if isinstance(value, str):
                    return list(ast.literal_eval(value))
                return list(value)

            # rows belonging to the current distillation stage
            current_stage = self.df[self.df.n_distills == self.n_distills]
            self.last_distill_timestep = current_stage[self.budget_type].min()
            self.best_loss = current_stage.R.min()
            best_t = current_stage[current_stage.R == self.best_loss]['t'].iloc[-1]
            # number of iterations since the best cost of this stage was observed
            self.n_fail = self.df.t.max() - best_t
            if self.arch_policy == 'schedule':
                self.policy_net = [32] * min(5, self.n_distills + 2)
                self.value_net = [256] * min(6, self.n_distills + 3)
            elif self.arch_policy == 'static':
                self.policy_net = [32] * 4
                self.value_net = [256] * 5
            else:  # search
                self.policy_net = _parse_layer_sizes(self.df.policy_net.iloc[-1])
                self.value_net = _parse_layer_sizes(self.df.value_net.iloc[-1])
            logging.info(f'Resumed from existing: Last distill={self.last_distill_timestep}. n_fail={self.n_fail}. '
                         f'n_distills={self.n_distills}. Last timestep={self.df[self.budget_type].max()}.'
                         f' Current policy net = {self.policy_net},'
                         f'Current value net={self.value_net}.')
        else:
            self.last_distill_timestep = 0
            self.n_fail = 0
            self.best_loss = float('inf')
            if self.arch_policy == 'schedule':
                self.policy_net = [32] * 2
                self.value_net = [256] * 3
            elif self.arch_policy == 'static':
                self.policy_net = [32] * 4
                self.value_net = [256] * 5
            # NOTE(review): for arch_policy == 'search' no network sizes are set here, yet
            # _search_init() reads self.policy_net / self.value_net; they presumably have to be
            # provided by BGPBT.__init__ -- confirm.

    def search_init(self):
        """Build the initial population together with the network architecture to use.

        Returns:
            A ``(population, policy_net, value_net)`` triple. With distillation enabled the whole
            triple comes from ``_search_init_distill``; otherwise the population comes from
            ``_search_init`` and the current ``self.policy_net`` / ``self.value_net`` are returned.
        """
        if not self.distillation:
            population = self._search_init()
            return population, self.policy_net, self.value_net
        return self._search_init_distill()

    def _search_init(self):
        """Search for good initialising configurations (no distillation, i.e. plain PBT start).

        Evaluates ``self.n_init`` candidate configurations for ``self.t_ready_start`` timesteps each
        (randomly sampled, or via BO in the branch below), keeps the ``pop_size`` lowest-cost ones,
        copies their checkpoints to the per-agent checkpoint paths and appends one row per selected
        agent to ``self.df``.

        Returns:
            A deep copy of ``self.pop`` in which the first ``min(pop_size, #candidates)`` agents are
            replaced by the selected configurations.
        """
        checkpoint_dir = f'{self.log_dir}/pb2_checkpoints'

        def init_checkpoint(idx):
            # checkpoint written for the idx-th candidate of the current distillation stage
            return os.path.join(checkpoint_dir,
                                f'{self.env.env_name}_seed{self.env.seed}_InitConfig{idx}_Stage{self.n_distills}.pt')

        def agent_checkpoint(idx):
            return os.path.join(checkpoint_dir, f'{self.env.env_name}_seed{self.env.seed}_Agent{idx}.pt')

        init_random_configs = [self.env.config_space.sample_configuration() for _ in range(self.n_init)]
        self.init_idx = 0

        # NOTE: with this threshold the BO branch is only used when n_init >= 10000 and the method is
        # 'pb2' / 'pb2mix'.
        if self.n_init < 10000 or self.method in ['pbt', 'random']:        # todo
            init_configs = init_random_configs
            init_store_paths = [init_checkpoint(i) for i in range(len(init_configs))]
            self.init_results_tmp = self.env.train_batch(configs=init_configs, seeds=[self.seed] * len(init_configs),
                                                         nums_timesteps=[self.t_ready_start] * len(init_configs),
                                                         checkpoint_paths=init_store_paths,
                                                         policy_hidden_layer_sizes=self.policy_net,
                                                         v_hidden_layer_sizes=self.value_net,
                                                         )
            # costs are negated rewards so that lower is better
            costs = [-get_reward_from_trajectory(r['y'], self.use_reward, 0.) for r in self.init_results_tmp]
        else:
            # use a mix of random and BO exploration (the choice of BO surrogate depends on the method: PB2: normal GP, PB2Mix: CoCaBO
            # define objective function and search spaces in accordance to the CoCaBO API
            self.init_results_tmp = []

            def f(configs):
                """The objective function handle to be passed to the init_bo"""
                ckpt_paths = [init_checkpoint(i) for i in range(self.init_idx, self.init_idx + len(configs))]
                # halve the parallelism when most candidates use long unrolls
                n_large_unrolls = sum([c['unroll_length'] > 20 for c in configs])
                max_parallel = self.env.max_parallel // 2 if n_large_unrolls / len(
                    configs) > 0.5 else self.env.max_parallel
                trajectories = self.env.train_batch(configs=configs, seeds=[self.seed] * len(configs),
                                                    nums_timesteps=[self.t_ready_start] * len(configs),
                                                    max_parallel=max_parallel,
                                                    checkpoint_paths=ckpt_paths,
                                                    policy_hidden_layer_sizes=self.policy_net,
                                                    v_hidden_layer_sizes=self.value_net,
                                                    )
                self.init_results_tmp += trajectories
                reward = [-get_reward_from_trajectory(np.array(t['y']), use_last_fraction=self.use_reward) for t
                          in trajectories]
                self.init_idx += len(configs)
                return reward

            # this is just an end-to-end BO (not PBT)
            init_bo = Casmo4RL(self.env, self.log_dir, max_iters=self.n_init, max_timesteps=self.t_ready_start,
                               batch_size=min(8, self.env.max_parallel), n_init=len(self.env.config_space), ard=False,
                               obj_func=f, use_standard_gp=True)
            init_configs, costs = init_bo.run()

        # Select on the number of candidates actually evaluated (the BO branch may not return exactly
        # n_init): argpartition with kth=pop_size needs strictly more than pop_size entries.
        if len(init_configs) <= self.pop_size:
            top_config_ids = list(range(len(init_configs)))
        else:  # only start PBT from the ``pop_size`` lowest-cost candidates
            top_config_ids = np.argpartition(np.array(costs), self.pop_size)[:self.pop_size].tolist()
        new_pop = deepcopy(self.pop)
        for i, config_id in enumerate(top_config_ids):
            new_pop[i] = {
                'done': False,
                'config': init_configs[config_id],
                'path': agent_checkpoint(i),
                'config_source': 'random',
                'excluded': False,
            }
            if self.env.env_name not in ['dummy', 'synthetic']:
                shutil.copy(init_checkpoint(config_id), agent_checkpoint(i))

        # Record the selected candidates (only those that made it into the population) in the stats
        # dataframe, in candidate order. The path is taken from the new population, i.e. the file the
        # checkpoint was just copied to.
        agent_of_config = {config_id: agent for agent, config_id in enumerate(top_config_ids)}
        t = 1 if self.n_distills == 0 else self.df.t.max() + 1
        for config_id in sorted(agent_of_config):
            agent_number = agent_of_config[config_id]
            config = init_configs[config_id]
            conf_array = config.get_array().tolist()
            final_cost = costs[config_id]
            rl_reward = get_reward_from_trajectory(self.init_results_tmp[config_id]['y'], 1, 0.)
            scalar_steps = self.init_results_tmp[config_id]['x'][-1]
            d = pd.DataFrame(columns=self.df.columns)
            d.loc[0] = [agent_number, t, scalar_steps, final_cost, rl_reward, conf_array,
                        new_pop[agent_number]['path'], config, 'random',
                        False,
                        self.n_distills, self.policy_net, self.value_net]
            self.df = pd.concat([self.df, d]).reset_index(drop=True)
            logging.info("\nAgent: {}, Timesteps: {}, Cost: {}\n".format(agent_number, scalar_steps, final_cost, ))
        del self.init_results_tmp, self.init_idx
        return new_pop

    def _search_init_distill(self):
        """Initialise a population for the current distillation stage.

        For the 'static' and 'schedule' architecture policies the network sizes follow a fixed rule
        ('schedule' grows with ``self.n_distills``, capped at 5 policy / 6 value layers). The initial
        configurations are then searched with ``_search_init`` using that architecture.

        Returns:
            ``(population, policy_net, value_net)``.

        Raises:
            NotImplementedError: for ``arch_policy == 'search'``.
        """
        if self.arch_policy == 'static':
            policy_net = [32] * 4
            value_net = [256] * 5
        elif self.arch_policy == 'schedule':
            policy_net = [32] * min(5, self.n_distills + 2)
            value_net = [256] * min(6, self.n_distills + 3)
        else:
            raise NotImplementedError(f"arch_policy={self.arch_policy!r} is not supported yet")  # todo
        # _search_init trains the candidates with self.policy_net / self.value_net. Adopt the
        # architecture being returned first, so the initial checkpoints match the architecture the
        # population is subsequently trained with.
        self.policy_net, self.value_net = policy_net, value_net
        pop = self._search_init()
        return pop, policy_net, value_net

    def run(self, ):
        """Run the population-based search until every agent reaches ``self.max_timesteps``.

        Each iteration either (a) trains every agent for ``t_ready`` timesteps, or (b) when a
        distillation was scheduled at the end of the previous iteration (and the value net has at most
        6 layers), distills randomly chosen top agents (or the backtracking best checkpoint) into a
        population with new configs and the next network architecture. Results are appended to
        ``self.df``, the best checkpoint is tracked when backtracking, and every agent then goes
        through ``exploit``. Intermediate stats are written to ``stats_seed_{seed}_intermediate.csv``.

        Returns:
            The accumulated results dataframe ``self.df``.
        """
        all_done = False
        distill_at_this_step = False

        # search for the initial population (and network sizes) unless resuming an existing run
        if not self.resumed:
            logging.info('Searching for initialising configurations!')
            self.pop, self.policy_net, self.value_net = self.search_init()
        while not all_done:
            # Train all agents in parallel unless a distillation was scheduled; distillation is skipped
            # once the value net already has more than 6 layers.
            # NOTE(review): in that case distill_at_this_step stays True and is never reset -- confirm.
            if not distill_at_this_step or len(self.value_net) > 6:
                n_large_models = sum([is_large(c['config']) for c in self.pop.values()])
                # halve the parallelism when most agents use "large" configs
                max_parallel = self.env.max_parallel // 2 if n_large_models / len(
                    self.pop) > 0.5 else self.env.max_parallel
                logging.info(f'Max parallel for this iteration={max_parallel}. Last distillation = {self.last_distill_timestep}')
                logging.info(f'Running config={[c["config"] for c in self.pop.values()]}')
                if self.t_ready_end == self.t_ready_start or self.t_ready_end is None:
                    t_ready = self.t_ready_start
                else:
                    # interpolate linearly from t_ready_start to t_ready_end over the timestep budget
                    t_ready = int(
                        self.t_ready_start + (self.t_ready_end - self.t_ready_start) / self.max_timesteps * self.df[
                            self.budget_type].max())
                results_values = self.env.train_batch(configs=[c['config'] for c in self.pop.values()],
                                                      exp_idx_start=0,
                                                      max_parallel=max_parallel,
                                                      seeds=[self.seed] * len(self.pop),
                                                      nums_timesteps=[t_ready] * len(self.pop),
                                                      policy_hidden_layer_sizes=self.policy_net,
                                                      v_hidden_layer_sizes=self.value_net,
                                                      checkpoint_paths=[self.pop[i]['path'] for i in
                                                                        range(len(self.pop))])
            else:
                n = max(int(self.pop_size * self.quantile_fraction), 1)
                logging.info(f'Modifying the nets at this iteration & distilling.')

                # rank the latest iteration; only the first pop_size rows (not rows appended by exploit)
                last_entries = self.df[self.df['t'] == self.df.t.max()]
                last_entries = last_entries.iloc[:self.pop_size]
                ranked_last_entries = last_entries.sort_values(by=['R'], ignore_index=True,
                                                               ascending=False)  # worst (highest cost) first
                best_agents = list(ranked_last_entries.iloc[-n:]['Agent'].values)
                if self.arch_policy in ['static', 'schedule']:
                    new_policy_net = self.policy_net if self.arch_policy == 'static' else self.policy_net + \
                                                                                          [self.policy_net[-1]]
                    new_action_net = self.value_net if self.arch_policy == 'static' else self.value_net + \
                                                                                         [self.value_net[-1]]
                    new_pop = deepcopy(self.pop)
                else:
                    new_pop, new_policy_net, new_action_net = self.search_init()

                best_agent_paths = []
                teacher_configs = []
                for agent in range(len(new_pop)):
                    best_agent = random.sample(best_agents, 1)[0]
                    if self.backtrack:
                        best_agent_path = self.best_checkpoint_dir
                    else:
                        best_agent_path = self.pop[best_agent]['path']
                    if self.arch_policy in ['static', 'schedule']:
                        # student config: a start point derived from the teacher's config, or a fresh sample
                        if np.random.random() < 0.5:
                            new_pop[agent]['config'] = \
                                get_start_point(self.env.config_space, self.pop[best_agent]['config'].get_array(),
                                                return_config=True)[1]
                        else:
                            new_pop[agent]['config'] = self.env.config_space.sample_configuration()
                    best_agent_paths.append(best_agent_path)
                    teacher_configs.append(self.pop[best_agent]['config'])

                logging.info(f'Distilling config={teacher_configs} to {[c["config"] for c in new_pop.values()]}')

                results_values = self.env.distill_batch(teacher_configs=teacher_configs,
                                                        student_configs=[c['config'] for c in new_pop.values()],
                                                        seeds=[self.seed] * len(self.pop),
                                                        distill_nums_timesteps=[self.n_distillation_timesteps] * len(
                                                            self.pop),
                                                        train_nums_timesteps=[self.t_ready_start] * len(self.pop),
                                                        fixed_teacher_params={
                                                            'policy_hidden_layer_sizes': self.policy_net,
                                                            'v_hidden_layer_sizes': self.value_net},
                                                        fixed_student_params={
                                                            'policy_hidden_layer_sizes': new_policy_net,
                                                            'v_hidden_layer_sizes': new_action_net},
                                                        checkpoint_paths=best_agent_paths,
                                                        student_checkpoint_paths=[c['path'] for c in self.pop.values()],
                                                        max_parallel=2)

                self.policy_net = new_policy_net
                self.value_net = new_action_net
                self.pop = new_pop

                distill_at_this_step = False
            results_keys = list(self.pop.keys())
            results = dict(zip(results_keys, results_values))

            if self.df.t.empty:
                t = 1
            else:
                t = self.df.t.max() + 1

            for agent in self.pop.keys():

                if self.pop[agent]['done']:
                    logging.info(f'Skipping completed agent {agent}.')
                    continue

                # negative sign to convert the reward maximisation into a minimisation problem
                final_cost = -get_reward_from_trajectory(results[agent]['y'], self.use_reward, 0.)
                rl_reward = get_reward_from_trajectory(results[agent]['y'], 1, 0.)
                final_timestep = results[agent]['x'][-1]

                # cumulative timesteps of this agent
                if self.df[self.df['Agent'] == agent].empty:
                    scalar_steps = final_timestep
                else:
                    scalar_steps = final_timestep + self.df[self.df['Agent'] == agent][self.budget_type].max()
                logging.info("\nAgent: {}, Timesteps: {}, Cost: {}\n".format(agent, scalar_steps, final_cost))

                conf_array = self.pop[agent]['config'].get_array().tolist()
                conf = self.pop[agent]['config']
                config_source = self.pop[agent]['config_source']
                d = pd.DataFrame(columns=self.df.columns)
                d.loc[0] = [agent, t, scalar_steps, final_cost, rl_reward, conf_array,
                            self.pop[agent]['path'], conf, config_source, self.pop[agent]['excluded'], self.n_distills,
                            self.policy_net, self.value_net]
                self.df = pd.concat([self.df, d]).reset_index(drop=True)

                if self.df[self.df['Agent'] == agent][self.budget_type].max() >= self.max_timesteps:
                    self.pop[agent]['done'] = True

            # keep a copy of the best checkpoint seen so far (used as the source when backtracking)
            if self.backtrack and self.env.env_name not in ['dummy', 'synthetic']:
                best_cost = self.df['R'].min()
                if best_cost < self.best_cost:
                    self.best_cost = best_cost
                    overall_best_agent = self.df[self.df['R'] == best_cost].iloc[-1]
                    shutil.copy(overall_best_agent['path'], self.best_checkpoint_dir)

            # exploitation -- copy the weights and etc.
            for agent in self.pop.keys():
                old_conf = self.pop[agent]['config'].get_array().tolist()
                self.pop[agent], copied = self.exploit(agent, )
                new_conf = self.pop[agent]['config'].get_array().tolist()
                # Sum absolute differences: a signed sum can cancel (+d and -d) and hide a real change.
                # NaN entries are ignored by nansum.
                conf_change = np.nansum(np.abs(np.array(old_conf) - np.array(new_conf)))
                if np.isclose(0, conf_change):
                    continue
                logging.info("changing conf for agent: {}".format(agent))
                source_rows = self.df[(self.df['Agent'] == copied) & (self.df['t'] == self.df.t.max())]
                if source_rows.empty:
                    logging.warning(f'No results row for agent {copied} at t={self.df.t.max()}; '
                                    f'not recording the config change of agent {agent}.')
                    continue
                # Use an explicit single-row copy: the filter can match several rows once rows re-labelled
                # earlier in this loop share the same (Agent, t), and editing a slice of self.df would
                # trigger pandas' chained-assignment warning. The row keeps the source agent's 'path'.
                new_row = source_rows.head(1).copy()
                new_row['Agent'] = agent
                logging.info(f"new row conf old: {new_row['conf']}")
                logging.info(f"new row conf new: {[new_conf]}")
                new_row['conf'] = [new_conf]
                new_row['conf_'] = [CS.Configuration(self.env.config_space, vector=new_conf)]
                self.df = pd.concat([self.df, new_row]).reset_index(drop=True)
                logging.info(f"new config: {new_conf}")

            all_done = np.array([self.pop[agent]['done'] for agent in self.pop.keys()]).all()
            # save intermediate results
            self.df.to_csv(os.path.join(self.log_dir, f'stats_seed_{self.seed}_intermediate.csv'))
            self.last_distill_timestep = self.df[self.df.n_distills == self.n_distills][
                self.budget_type].min()  # record the timestep as the last time we undergo distillation.
            t_max = self.df[self.budget_type].max()
            best_loss = self.df[self.df['n_distills'] == self.n_distills].R.min()
            # reset the failure counter when the latest results match the best cost of this stage
            if self.df[
                (self.df[self.budget_type] == t_max) & (self.df['n_distills'] == self.n_distills)].R.min() == best_loss:
                self.n_fail = 0
            else:
                self.n_fail += 1
            # schedule a distillation after `patience` non-improving iterations or every `distill_every` steps
            if self.distillation:
                if self.n_fail >= self.patience or t_max - self.last_distill_timestep > self.distill_every:
                    distill_at_this_step = True
                    self.n_distills += 1
                    self.n_fail = 0
                    self.best_cost = float('inf')
                    logging.info('Start distillation in the next iteration..')
                logging.info(f'n_fail: {self.n_fail}')

        return self.df

    def exploit(self, agent):
        """Exploitation step (truncation selection) for one agent.

        - method 'random': the agent keeps its weights and gets a freshly sampled configuration.
        - otherwise, if the agent is among the ``n = max(int(pop_size * quantile_fraction), 1)`` worst
          agents (highest cost ``R``) of the latest iteration, it copies the checkpoint of a randomly
          chosen top-``n`` agent (or the backtracking best checkpoint) and gets a new configuration
          from ``explore``.

        Args:
            agent: key of the agent in ``self.pop``.

        Returns:
            ``(agent_state, source_agent)``: the updated ``self.pop[agent]`` and the agent whose latest
            results row the caller attributes to it (the agent itself when nothing was copied).
        """
        if self.method == 'random':
            self.pop[agent]['config'] = self.env.config_space.sample_configuration()
            # no weights are copied in random search: the agent continues from its own checkpoint
            return self.pop[agent], agent

        if self.df[self.df['Agent'] == agent].t.empty:
            return self.pop[agent], agent

        n = max(int(self.pop_size * self.quantile_fraction), 1)
        max_t = self.df.t.max()  # last iteration entry
        last_entries = self.df[self.df['t'] == max_t]  # index entire population based on last set of runs
        last_entries = last_entries.iloc[:self.pop_size]  # only want the original entries
        ranked_last_entries = last_entries.sort_values(by=['R'], ignore_index=True,
                                                       ascending=False)  # worst (highest cost) first
        ranked_agents = list(ranked_last_entries.Agent.values)
        if agent not in ranked_agents:
            # e.g. an agent that is already done and therefore was not evaluated in the latest iteration
            self.pop[agent]['config_source'] = 'previous'
            return self.pop[agent], agent

        position = ranked_agents.index(agent) + 1  # 1-based; position 1 is the worst agent
        if position > n:
            # not exploiting, not exploring... move on :)
            self.pop[agent]['config_source'] = 'previous'
            return self.pop[agent], agent

        best_agents = list(ranked_last_entries.iloc[-n:]['Agent'].values)
        best_agent = random.sample(best_agents, 1)[0]

        if self.env.env_name not in ['synthetic', 'dummy']:
            best_path = self.best_checkpoint_dir if self.backtrack else self.pop[best_agent]['path']
            current_path = self.pop[agent]['path']
            shutil.copy(best_path, current_path)

        new_config, eps, source = self.explore(agent, best_agent, self.pop[best_agent]['config'], self.df)
        self.pop[agent]['config'] = new_config
        if self.method in ('pb2mix', 'pbt'):
            # these methods report a (continuous, categorical) pair; PBT reports (0, 0)
            self.pop[agent]['Eps_cont'] = eps[0]
            self.pop[agent]['Eps_cat'] = eps[1]
        else:
            self.pop[agent]['Eps_cont'] = eps
            self.pop[agent]['Eps_cat'] = 0
        self.pop[agent]['config_source'] = source

        logging.info("\n replaced agent {} with agent {}".format(agent, best_agent))
        logging.info(self.pop[agent]['config'])
        return self.pop[agent], best_agent

    def explore(self, agent, best_agent, config, df):
        """Propose a new configuration for ``agent`` after it copied ``best_agent``.

        Args:
            agent: the agent being explored.
            best_agent: the agent that was copied.
            config: ``best_agent``'s configuration (the starting point for PBT perturbation).
            df: results dataframe used by the PB2 surrogate.

        Returns:
            ``(new_config, eps, source)``. For PBT ``eps`` is ``(0, 0)`` and ``source`` is ``'pbt'``;
            for PB2 / PB2-Mix the result of ``explore_PB2`` is returned (unpacked as a triple by
            ``exploit``).

        Raises:
            ValueError: for a method without an exploration step (e.g. 'random', which ``exploit``
                handles on its own).
        """
        if self.method == 'pbt':
            return self.explore_PBT(config), (0, 0), 'pbt'
        if self.method in ['pb2', 'pb2mix']:
            return self.explore_PB2(agent, best_agent, df)
        raise ValueError(f'No exploration step is defined for method={self.method!r}.')

    def explore_PBT(self, config: CS.Configuration):
        """Classic PBT exploration: return a perturbed deep copy of ``config``.

        Categorical hyperparameters are resampled from their choices when a uniform draw exceeds
        ``self.categorical_prob``. Every other hyperparameter is multiplied by a factor picked
        uniformly from ``self.perturb_amount`` and clipped to ``[lower, upper]``. Uniform/normal
        integer hyperparameters are then rounded to ``int``. The input configuration is not modified.
        """
        logging.info("\nPBT Explore\n")

        space = self.env.config_space
        perturbed = deepcopy(config)
        integer_types = (CSH.UniformIntegerHyperparameter, CSH.NormalIntegerHyperparameter)
        for name in perturbed:
            hp = space[name]
            if type(hp) == CSH.CategoricalHyperparameter:
                resample = np.random.rand() > self.categorical_prob
                if resample:
                    perturbed[name] = np.random.choice(hp.choices)
                continue
            factor = self.perturb_amount[int(np.random.choice([0, 1]))]
            scaled = np.clip(perturbed[name] * factor, hp.lower, hp.upper)
            perturbed[name] = int(np.round(scaled)) if type(hp) in integer_types else scaled

        return perturbed

    def format_df(self, agent, copied, df):
        """Build the PB2 surrogate training data from the results dataframe.

        Side effect: stores, in ``self.running[str(agent_t)]``, the ``conf`` of every agent that is
        not among the ``n`` worst (highest ``R``) at iteration ``agent_t``. These are the configurations
        that keep running, and ``explore_PB2`` reads them back from ``self.running``.

        Args:
            agent: the agent being explored; its latest iteration ``t`` defines ``agent_t``.
            copied: the agent it copied; its rows of ``data`` form ``dfnewpoint``.
            df: the full results dataframe (not modified).

        Returns:
            ``(dfnewpoint, data, agent_t)``. ``data`` has, per consecutive pair of rows of the same agent
            with an unchanged config, the change in ``R`` per budget unit (``y``), the budget increment
            (``t_change``), ``R_before`` and the config columns ``x0..x{d-1}``. It is sorted by budget and
            truncated to the latest 1000 rows.
        """
        n = max(int(self.pop_size * self.quantile_fraction), 1)
        agent_t = df[df['Agent'] == agent].t.max()  # latest iteration of this agent
        last_entries = df[df['t'] == agent_t]  # whole population at that iteration
        ranked_last_entries = last_entries.sort_values(by=['R'], ignore_index=True)  # best (lowest cost) first

        # every agent except the n worst keeps running with its current configuration
        running_now = self.running.setdefault(str(agent_t), {})
        for a in ranked_last_entries.iloc[:-n]['Agent'].values:
            running_now[str(a)] = df[(df['Agent'] == a) & (df['t'] == agent_t)]['conf'].values[0]

        x_cols = ['x{}'.format(i) for i in range(len(df.conf[0]))]
        # explicit copy: the new columns are added to an independent frame, not a slice of df
        data = df[['Agent', 't', self.budget_type, 'R']].copy()
        data[x_cols] = pd.DataFrame(df.conf.tolist(), index=df.index)

        # Differences between consecutive rows of the same agent with an identical config.
        # NOTE(review): groupby drops rows whose keys contain NaN (e.g. inactive conditional
        # hyperparameters) -- confirm that excluding them is intended.
        group_cols = ['Agent'] + x_cols
        data["y"] = data.groupby(group_cols)["R"].diff()
        data["t_change"] = data.groupby(group_cols)[self.budget_type].diff()

        data = data[data["t_change"] > 0].reset_index(drop=True)
        data["R_before"] = data.R - data.y

        data["y"] = data.y / data.t_change
        data = data[~data.y.isna()].reset_index(drop=True)
        data = data.sort_values(by=self.budget_type).reset_index(drop=True)
        data = data.iloc[-1000:, :].reset_index(drop=True)
        dfnewpoint = data[data["Agent"] == copied]
        return dfnewpoint, data, agent_t

    def explore_PB2(self, agent, copied, df):
        """Propose a new hyperparameter configuration for ``agent`` (PB2 / PB2-Mix explore step).

        If ``format_df`` returns a non-empty ``dfnewpoint``:
        - categorical dimensions are chosen with ``exp3_get_cat`` (``pb2mix`` only);
        - continuous dimensions are chosen by :meth:`select_config`;
        - the result is mapped back onto the environment's config space.

        Otherwise a configuration is sampled at random from ``self.env.config_space``.

        The proposed vector is stored in ``self.running[str(agent_t)][str(agent)]``. Later
        explore calls for the same ``agent_t`` then treat it as a pending point.

        Args:
            agent: agent being explored; stored under ``str(agent)``.
            copied: agent whose state was copied; forwarded to ``format_df``.
            df: population history. ``df.conf[0]`` must have one entry per hyperparameter
                dimension (one per row of ``self.mutations``).

        Returns:
            ``(new_config, eps, source)``:
            - ``new_config``: the new ``Configuration``.
            - ``eps``: the kernel epsilon(s) returned by :meth:`select_config`
              (``[0, 0]`` / ``0`` for random samples).
            - ``source``: ``'bo'`` or ``'random'``.
        """
        df = df.copy()

        logging.info("\nPB2 Explore\n")

        n_dims = len(df.conf[0])
        dim_types = self.mutations.Type.values
        self.cont_vars = ['x{}'.format(i) for i in range(n_dims) if dim_types[i] == 'continuous']
        self.cat_vars = ['x{}'.format(i) for i in range(n_dims) if dim_types[i] == 'categorical']
        self.all_vars = ['x{}'.format(i) for i in range(n_dims)]

        dfnewpoint, data, agent_t = self.format_df(agent, copied, df)
        # Vectors already proposed for other agents at this agent_t. The key only exists once a
        # proposal has been stored (see the end of this method), so do not index it directly.
        pending = self.running.get(str(agent_t), {})

        if not dfnewpoint.empty:

            to_use = {'x{}'.format(i): 0 for i in range(len(self.mutations))}

            # select categorical variables first (PB2-Adv/PB2-CoCa, only for pb2mix)
            if self.method == 'pb2mix':
                for i in range(len(self.mutations)):
                    row = self.mutations.iloc[i]
                    if row.Type != 'categorical':
                        continue
                    # Fresh copy per dimension, exactly as before, in case exp3_get_cat mutates it.
                    data_cat = data.copy()
                    data_cat["y_exp3"] = normalize(data_cat['y'], data_cat['y'])
                    pendingactions = [x[i] for x in pending.values()]
                    to_use['x{}'.format(i)] = exp3_get_cat(row, data_cat, self.numRounds, pendingactions)

            y = np.array(data.y.values)
            hparams = data[self.all_vars]
            current = list(pending.values())
            t_r = data[[self.budget_type, "R_before"]]
            X = pd.concat([t_r, hparams], axis=1).values

            # PB2-Mix conditions the GP on the categorical values, so they travel with the
            # fixed (budget, reward-before) context of the new point.
            fixed_cols = [self.budget_type, "R_before"]
            if self.method == 'pb2mix':
                fixed_cols = fixed_cols + self.cat_vars
            newpoint = dfnewpoint.iloc[-1, :][fixed_cols].values

            new, eps = self.select_config(X, y, current, newpoint, self.mutations, num_f=len(t_r.columns))
            for i, cont_idx in enumerate(self.cont_vars):
                to_use[cont_idx] = new[i]
            to_use = np.array(list(to_use.values()))

            # to_use follows self.mutations order and may contain one-hot dimensions. Rebuild the
            # config-space vector by name, collapsing each one-hot group to its argmax index.
            cs_array = np.nan * np.ones(len(self.env.config_space))
            for i, element in enumerate(to_use):
                dim_name = self.mutations['Name'].iloc[i]
                try:
                    cs_idx = self.env.config_space.get_idx_by_hyperparameter_name(dim_name)
                    cs_array[cs_idx] = element
                except (KeyError, IndexError):
                    # NOTE(review): ConfigSpace may raise KeyError (not IndexError) for unknown
                    # names, so both are caught. Unknown names are assumed to be one-hot
                    # sub-dimensions named '<hp>_<suffix>'.
                    cs_idx = self.env.config_space.get_idx_by_hyperparameter_name(dim_name.split('_')[0])
                    if not np.isnan(cs_array[cs_idx]):
                        continue  # this one-hot group was already collapsed
                    for one_hot_group in self.one_hot_dim_indices:
                        if i in one_hot_group:
                            cs_array[cs_idx] = np.argmax(to_use[one_hot_group])
                            break
            # Build the configuration from the transformed vector, not the raw search vector.
            new_config = CS.Configuration(self.env.config_space, vector=cs_array)
            source = 'bo'
        else:
            new_config = self.env.config_space.sample_configuration()
            to_use = new_config.get_array().tolist()
            eps = [0, 0] if self.method == 'pb2mix' else 0
            source = 'random'

        self.running.setdefault(str(agent_t), {})[str(agent)] = to_use

        return new_config, eps, source

    def select_config(self, Xraw, yraw, current, newpoint, mutations, num_f):
        """Select the next hyperparameter config to try.

        Fits a GP to the history, using ``TV_SquaredExp`` or, for ``pb2mix``,
        ``TV_MixtureViaSumAndProduct``. If there are pending points, a second GP ``m1`` is also
        fitted on the history augmented with those points. The continuous hyperparameters are
        then chosen with ``optimize_acq(LCB, m, m1, ...)``, keeping the new point's context fixed.

        Args:
            Xraw: history matrix. The first ``num_f`` columns are context (e.g. budget and
                reward-before); the remaining columns are the hyperparameter dimensions
                ``x0..xk``.
            yraw: observed objective for each row of ``Xraw``.
            current: vectors in the same hyperparameter layout as ``Xraw[:, num_f:]`` that
                were already proposed for running agents. ``None`` or empty means no pending
                points.
            newpoint: context for the new point. This is ``num_f`` values, followed by the
                fixed categorical values when ``self.method == 'pb2mix'``.
            mutations: table with per-dimension ``Range`` and ``Type`` columns.
            num_f: number of leading context columns in ``Xraw`` / ``newpoint``.

        Returns:
            ``(xt, epsilon)``:
            - ``xt``: the proposed continuous hyperparameters, rescaled to ``Range`` (float32).
            - ``epsilon``: the fitted kernel epsilon(s), ``[eps_1, eps_2]`` for ``pb2mix``.
        """
        Xraw = np.array(Xraw, dtype=float)
        yraw = np.array(yraw, dtype=float)
        newpoint = np.array(newpoint, dtype=float)
        # No pending points is handled by the `current is None` branch below. An empty
        # array would otherwise reach normalize() and fail to broadcast against base_vals.
        if current is not None and len(current) == 0:
            current = None
        if current is not None:
            current = np.array(current, dtype=float)

        oldpoints = Xraw[:, :num_f]
        X_use = Xraw[:, num_f:]

        if self.method == 'pb2mix':
            # 'x12' -> 12: parse every digit after the 'x' prefix, not only the first one.
            cat_dims = [int(x[1:]) for x in self.cat_vars]
            X_cat = X_use[:, cat_dims]
            X_use = X_use[:, [x for x in range(X_use.shape[1]) if x not in cat_dims]]
            fixed_cat = newpoint[num_f:]
            newpoint = newpoint[:num_f]
            if current is not None:
                current_cat = [[x for idx, x in enumerate(curr) if idx in cat_dims] for curr in current]
                current = [[x for idx, x in enumerate(curr) if idx not in cat_dims] for curr in current]
        else:
            cat_dims = []

        X_use = np.concatenate((oldpoints, X_use), axis=1)

        # Ranges of the non-categorical dimensions, transposed to one row per bound and
        # row-reversed to mirror old_lims ([max; min]). Assumes each Range is (low, high).
        base_vals = [val for (val, kind) in zip(mutations.Range.values, mutations.Type.values)
                     if kind != 'categorical']
        base_vals = np.array(base_vals).T[::-1]

        fixed_points = np.concatenate((oldpoints, newpoint.reshape(1, -1)), axis=0)
        old_lims = np.concatenate((np.max(fixed_points, axis=0),
                                   np.min(fixed_points, axis=0))).reshape(2, oldpoints.shape[1])

        # Offsetting the bounds keeps max != min when all context values coincide.
        old_lims[0] -= 1e-8
        old_lims[1] += 1e-8

        limits = np.concatenate((old_lims, base_vals), axis=1)
        limits[0] -= 1e-8
        limits[1] += 1e-8

        X = normalize(X_use, limits)
        y = standardize(yraw).reshape(yraw.size, 1)

        fixed = normalize(newpoint, old_lims)

        if self.method == 'pb2mix':
            # Column layout: [context | categorical | continuous].
            X = np.concatenate((X[:, :num_f], X_cat, X[:, num_f:]), axis=1)
            cat_locs = [x + num_f for x in range(X_cat.shape[1])]
            kernel = TV_MixtureViaSumAndProduct(X.shape[1],
                                                variance_1=1.,
                                                variance_2=1.,
                                                variance_mix=1.,
                                                lengthscale=1.,
                                                epsilon_1=0.,
                                                epsilon_2=0.,
                                                mix=0.5,
                                                cat_dims=cat_locs)
        else:
            cat_locs = []
            kernel = TV_SquaredExp(
                input_dim=X.shape[1], variance=1., lengthscale=1., epsilon=0.1)

        def _jittered(arr):
            # Perturb the (n, d) inputs along their main diagonal. np.eye(n) alone does not
            # broadcast against (n, d) unless n == d. Categorical columns are left untouched
            # (presumably the mixture kernel compares them by exact value -- confirm).
            jitter = 1e-3 * np.eye(arr.shape[0], arr.shape[1])
            jitter[:, cat_locs] = 0.
            return arr + jitter

        try:
            m = GPy.models.GPRegression(X, y, kernel)
        except np.linalg.LinAlgError:
            # ** we would ideally make this something more robust...
            X = _jittered(X)
            m = GPy.models.GPRegression(X, y, kernel)
        try:
            m.optimize()
        except np.linalg.LinAlgError:
            X = _jittered(X)
            m = GPy.models.GPRegression(X, y, kernel)
            m.optimize()

        m.kern.lengthscale.fix(m.kern.lengthscale.clip(1e-5, 1))

        if current is None:
            m1 = deepcopy(m)
        else:
            # Add the pending points, each paired with the new point's normalized context.
            current_use = normalize(current, base_vals)
            padding = np.array([fixed for _ in range(current_use.shape[0])])

            if self.method == 'pb2mix':
                current_use = np.concatenate((padding, current_cat, current_use), axis=1)
            else:
                current_use = np.hstack((padding, current_use))

            Xnew = np.vstack((X, current_use))

            # y value doesn't matter, only care about the variance.
            ynew = np.vstack((y, np.zeros((current_use.shape[0], 1))))

            if self.method == 'pb2mix':
                kernel = TV_MixtureViaSumAndProduct(Xnew.shape[1],
                                                    variance_1=1.,
                                                    variance_2=1.,
                                                    variance_mix=1.,
                                                    lengthscale=1.,
                                                    epsilon_1=0.,
                                                    epsilon_2=0.,
                                                    mix=0.5,
                                                    cat_dims=cat_locs)
            else:
                kernel = TV_SquaredExp(
                    input_dim=X.shape[1], variance=1., lengthscale=1., epsilon=0.1)
            m1 = GPy.models.GPRegression(Xnew, ynew, kernel)
            m1.optimize()

        if self.method == 'pb2mix':
            # The categorical values are held fixed alongside the context during optimisation.
            fixed = np.concatenate((fixed.reshape(1, -1), fixed_cat.reshape(1, -1)), axis=1)
            xt = optimize_acq(LCB, m, m1, fixed, num_f + len(cat_dims))
        else:
            xt = optimize_acq(LCB, m, m1, fixed, num_f)

        # Map from the unit cube back to the original Range scale.
        lo = np.min(base_vals, axis=0)
        hi = np.max(base_vals, axis=0)
        xt = (xt * (hi - lo) + lo).astype(np.float32)

        if self.method == 'pb2mix':
            epsilon = [m.kern.epsilon_1[0], m.kern.epsilon_2[0]]
        else:
            epsilon = m.kern.epsilon[0]

        return (xt, epsilon)
